from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from .CSPdarknet import darknet53
from .densenet import _Transition, densenet121, densenet169, densenet201
from .ghostnet import ghostnet
from .mobilenet_v1 import mobilenet_v1
from .mobilenet_v2 import mobilenet_v2
from .mobilenet_v3 import mobilenet_v3
from .resnet import resnet50
from .vgg import vgg

#-----------------------------------------#
#   Pooling structures (SPP variants); YoloBody selects SppType[spp] when spp >= 1.
#-----------------------------------------#
from nets.spp.denseaspp import DenseASPP
from nets.spp.other_spp import ASPPBN, PPM, SpatialPyramidPooling, ASPPNOBN
SppType = [SpatialPyramidPooling, DenseASPP, ASPPBN, PPM, ASPPNOBN]
# Pool layer classes forwarded to SpatialPyramidPooling when spp == 0.
# NOTE(review): nn.AvgPool2d appears at both index 1 and index 2 — confirm
# whether the duplicate is intentional or one entry was meant to differ.
PoolType = [nn.MaxPool2d, nn.AvgPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d]

#---------------------------------------------------#
#   Attention mechanisms; phi >= 1 selects attention_block[phi - 1].
#---------------------------------------------------#
from nets.attention import cbam_block, eca_block, se_block, CA_Block
attention_block = [se_block, cbam_block, eca_block, CA_Block]

class MobileNetV1(nn.Module):
    """MobileNetV1 backbone wrapper.

    Runs stages 0-5 sequentially, optionally applying an attention block
    after each stage, and returns the last three (or, with
    add_feature_layer >= 1, last four) stage outputs as feature maps.
    """

    # Output channels of stage0..stage5, used to size the attention blocks.
    _STAGE_CHANNELS = (32, 64, 128, 256, 512, 1024)

    def __init__(self, pretrained=False, phi=0, add_feature_layer=0):
        super(MobileNetV1, self).__init__()

        self.model = mobilenet_v1(pretrained=pretrained)

        # phi selects the attention mechanism (0 = none).
        self.phi = phi
        if self.phi >= 1:
            # Register out0_att..out5_att (same attribute names as before,
            # so checkpoints remain compatible).
            for stage_idx, channels in enumerate(self._STAGE_CHANNELS):
                setattr(self, "out%d_att" % stage_idx,
                        attention_block[self.phi - 1](channels))

        # add_feature_layer >= 1 additionally returns the shallower stage2 map.
        self.add_feature_layer = add_feature_layer

    def forward(self, x):
        outputs = []
        feat = x
        for stage_idx in range(6):
            feat = getattr(self.model, "stage%d" % stage_idx)(feat)
            if self.phi >= 1:
                feat = getattr(self, "out%d_att" % stage_idx)(feat)
            outputs.append(feat)

        if self.add_feature_layer >= 1:
            return outputs[2], outputs[3], outputs[4], outputs[5]
        return outputs[3], outputs[4], outputs[5]

class MobileNetV2(nn.Module):
    """MobileNetV2 backbone wrapper.

    Splits model.features into seven consecutive slices, optionally applies
    an attention block after each slice, and returns three (or four, with
    add_feature_layer >= 1) of the intermediate feature maps.
    """

    # (start, stop) slice bounds into model.features for each sub-stage.
    _SLICES = ((0, 2), (2, 4), (4, 7), (7, 11), (11, 14), (14, 17), (17, 18))
    # Output channels of each sub-stage, used to size the attention blocks.
    _CHANNELS = (16, 24, 32, 64, 96, 160, 320)

    def __init__(self, pretrained=False, phi=0, add_feature_layer=0):
        super(MobileNetV2, self).__init__()
        self.model = mobilenet_v2(pretrained=pretrained)

        # phi selects the attention mechanism (0 = none).
        self.phi = phi
        # add_feature_layer >= 1 additionally returns a shallower feature map.
        self.add_feature_layer = add_feature_layer

        if self.phi >= 1:
            # Register out0_att..out6_att (same attribute names as before,
            # so checkpoints remain compatible).
            for stage_idx, channels in enumerate(self._CHANNELS):
                setattr(self, "out%d_att" % stage_idx,
                        attention_block[self.phi - 1](channels))

    def forward(self, x):
        outputs = []
        feat = x
        for stage_idx, (start, stop) in enumerate(self._SLICES):
            feat = self.model.features[start:stop](feat)
            if self.phi >= 1:
                feat = getattr(self, "out%d_att" % stage_idx)(feat)
            outputs.append(feat)

        if self.add_feature_layer >= 1:
            return outputs[1], outputs[2], outputs[4], outputs[6]
        return outputs[2], outputs[4], outputs[6]

class MobileNetV3(nn.Module):
    """MobileNetV3 backbone wrapper.

    Slices model.features into four consecutive stages and returns the last
    three (or all four, with add_feature_layer >= 1) stage outputs.
    Attention (phi) is handled inside mobilenet_v3 itself.
    """

    # (start, stop) slice bounds into model.features for each stage.
    _SLICES = ((0, 4), (4, 7), (7, 13), (13, 16))

    def __init__(self, pretrained=False, phi=0, add_feature_layer=0):
        super(MobileNetV3, self).__init__()
        self.model = mobilenet_v3(pretrained=pretrained, phi=phi)
        # add_feature_layer >= 1 additionally returns the shallowest stage map.
        self.add_feature_layer = add_feature_layer

    def forward(self, x):
        feats = []
        cur = x
        for start, stop in self._SLICES:
            cur = self.model.features[start:stop](cur)
            feats.append(cur)
        if self.add_feature_layer >= 1:
            return feats[0], feats[1], feats[2], feats[3]
        return feats[1], feats[2], feats[3]

class GhostNet(nn.Module):
    """GhostNet backbone wrapper.

    Strips the classification head, then collects the outputs of blocks
    2, 4, 6 and 8 as feature maps. Returns all four with
    add_feature_layer >= 1, otherwise the deepest three.
    """

    # Indices of the blocks whose outputs serve as feature maps.
    _FEATURE_IDS = (2, 4, 6, 8)

    def __init__(self, pretrained=True, phi=0, add_feature_layer=0):
        super(GhostNet, self).__init__()
        backbone = ghostnet(phi=phi)
        if pretrained:
            # Weights are loaded from a fixed local path, not downloaded.
            backbone.load_state_dict(torch.load("model_data/ghostnet_weights.pth"))
        # Remove the classification head — only the feature extractor is used.
        for head_attr in ("global_pool", "conv_head", "act2", "classifier"):
            delattr(backbone, head_attr)
        del backbone.blocks[9]
        self.model = backbone
        # add_feature_layer >= 1 additionally returns the shallowest feature map.
        self.add_feature_layer = add_feature_layer

    def forward(self, x):
        y = self.model.act1(self.model.bn1(self.model.conv_stem(x)))

        feature_maps = []
        for idx, block in enumerate(self.model.blocks):
            y = block(y)
            if idx in self._FEATURE_IDS:
                feature_maps.append(y)

        if self.add_feature_layer >= 1:
            return feature_maps
        return feature_maps[1:]

class CSPdarknet(nn.Module):
    """CSPDarknet53 backbone wrapper.

    Runs conv1 then the five stages sequentially, optionally applying an
    attention block after each stage, and returns the outputs of stages
    2-4 (plus stage 1 when add_feature_layer >= 1).
    """

    # Output channels of stages[0]..stages[4], used to size attention blocks.
    _STAGE_CHANNELS = (64, 128, 256, 512, 1024)

    def __init__(self, pretrained=False, phi=0, add_feature_layer=0):
        super(CSPdarknet, self).__init__()
        self.model = darknet53(pretrained)
        # phi selects the attention mechanism (0 = none).
        self.phi = phi
        # add_feature_layer >= 1 additionally returns the stage-1 feature map.
        self.add_feature_layer = add_feature_layer

        if self.phi >= 1:
            # Register out1_att..out5_att (same attribute names as before,
            # so checkpoints remain compatible).
            for stage_no, channels in enumerate(self._STAGE_CHANNELS, start=1):
                setattr(self, "out%d_att" % stage_no,
                        attention_block[self.phi - 1](channels))

    def forward(self, x):
        feat = self.model.conv1(x)
        outputs = []
        for stage_idx in range(5):
            feat = self.model.stages[stage_idx](feat)
            if self.phi >= 1:
                feat = getattr(self, "out%d_att" % (stage_idx + 1))(feat)
            outputs.append(feat)

        if self.add_feature_layer >= 1:
            return outputs[1], outputs[2], outputs[3], outputs[4]
        return outputs[2], outputs[3], outputs[4]


class VGG(nn.Module):
    """VGG backbone wrapper.

    Slices model.features into five consecutive stages and returns the
    outputs of the last three stages as a list.
    """

    # (start, stop) slice bounds into model.features; None means "to the end".
    _SLICES = ((0, 5), (5, 10), (10, 17), (17, 24), (24, None))

    def __init__(self, pretrained=False):
        super(VGG, self).__init__()
        self.model = vgg(pretrained)

    def forward(self, x):
        feats = []
        cur = x
        for start, stop in self._SLICES:
            cur = self.model.features[start:stop](cur)
            feats.append(cur)
        # feats[2:] == [feat3, feat4, feat5]
        return feats[2:]

class Densenet(nn.Module):
    """DenseNet backbone wrapper.

    Collects the output of each _Transition's 1x1 conv plus the final
    (ReLU-activated) feature map, then returns all but the first as the
    effective feature maps.
    """

    _FACTORIES = {
        "densenet121": densenet121,
        "densenet169": densenet169,
        "densenet201": densenet201,
    }

    def __init__(self, backbone, pretrained=False):
        super(Densenet, self).__init__()
        net = self._FACTORIES[backbone](pretrained)
        # Only the convolutional trunk is used.
        del net.classifier
        self.model = net

    def forward(self, x):
        feature_maps = []
        for layer in self.model.features:
            # Inside a transition, tap the map right after the 1x1 conv
            # (before its pooling shrinks it).
            if type(layer) == _Transition:
                for sub in layer:
                    x = sub(x)
                    if type(sub) == nn.Conv2d:
                        feature_maps.append(x)
            else:
                x = layer(x)
        feature_maps.append(F.relu(x, inplace=True))
        return feature_maps[1:]

class ResNet(nn.Module):
    """ResNet-50 backbone wrapper.

    Runs the stem and the four residual layers, returning the outputs of
    layer2, layer3 and layer4 as a list.
    """

    def __init__(self, pretrained=False):
        super(ResNet, self).__init__()
        self.model = resnet50(pretrained)

    def forward(self, x):
        net = self.model
        stem = net.relu(net.bn1(net.conv1(x)))
        c2 = net.layer1(net.maxpool(stem))
        c3 = net.layer2(c2)
        c4 = net.layer3(c3)
        c5 = net.layer4(c4)
        return [c3, c4, c5]

def conv2d(filter_in, filter_out, kernel_size, groups=1, stride=1):
    """Conv2d (no bias) + BatchNorm2d + ReLU6, with 'same'-style padding.

    Padding is (kernel_size - 1) // 2 so spatial size is preserved at
    stride 1 (0 when kernel_size is 0/falsy).
    """
    padding = (kernel_size - 1) // 2 if kernel_size else 0
    layers = OrderedDict()
    layers["conv"] = nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size,
                               stride=stride, padding=padding, groups=groups,
                               bias=False)
    layers["bn"] = nn.BatchNorm2d(filter_out)
    layers["relu"] = nn.ReLU6(inplace=True)
    return nn.Sequential(layers)

def conv_dw(filter_in, filter_out, stride=1):
    """Depthwise-separable conv: 3x3 depthwise then 1x1 pointwise.

    Each convolution is bias-free and followed by BatchNorm2d + ReLU6;
    the stride is applied by the depthwise stage.
    """
    depthwise = [
        nn.Conv2d(filter_in, filter_in, 3, stride, 1, groups=filter_in, bias=False),
        nn.BatchNorm2d(filter_in),
        nn.ReLU6(inplace=True),
    ]
    pointwise = [
        nn.Conv2d(filter_in, filter_out, 1, 1, 0, bias=False),
        nn.BatchNorm2d(filter_out),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)


#---------------------------------------------------#
#   1x1 convolution followed by 2x upsampling
#---------------------------------------------------#
class Upsample(nn.Module):
    """Reduce channels with a 1x1 conv block, then 2x nearest-neighbour upsample."""

    def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
        self.upsample = nn.Sequential(
            conv2d(in_channels, out_channels, 1),
            nn.Upsample(scale_factor=2, mode='nearest'),
        )

    def forward(self, x):
        return self.upsample(x)

#---------------------------------------------------#
#   Three-convolution block
#---------------------------------------------------#
def make_three_conv(filters_list, in_filters):
    """1x1 conv -> depthwise-separable 3x3 -> 1x1 conv.

    Channels go in_filters -> filters_list[0] -> filters_list[1] ->
    filters_list[0].
    """
    stages = [
        conv2d(in_filters, filters_list[0], 1),
        conv_dw(filters_list[0], filters_list[1]),
        conv2d(filters_list[1], filters_list[0], 1),
    ]
    return nn.Sequential(*stages)

#---------------------------------------------------#
#   Five-convolution block
#---------------------------------------------------#
def make_five_conv(filters_list, in_filters):
    """Alternating 1x1 and depthwise-separable 3x3 convolutions (five total).

    Channels go in_filters -> filters_list[0], then bounce between
    filters_list[1] and filters_list[0], ending on filters_list[0].
    """
    stages = [
        conv2d(in_filters, filters_list[0], 1),
        conv_dw(filters_list[0], filters_list[1]),
        conv2d(filters_list[1], filters_list[0], 1),
        conv_dw(filters_list[0], filters_list[1]),
        conv2d(filters_list[1], filters_list[0], 1),
    ]
    return nn.Sequential(*stages)

#---------------------------------------------------#
#   Final YOLOv4 prediction head
#---------------------------------------------------#
def yolo_head(filters_list, in_filters):
    """Depthwise-separable 3x3 conv followed by a plain 1x1 prediction conv.

    filters_list[1] is the raw output channel count, i.e.
    num_anchors * (5 + num_classes).
    """
    head = [
        conv_dw(in_filters, filters_list[0]),
        nn.Conv2d(filters_list[0], filters_list[1], 1),
    ]
    return nn.Sequential(*head)


#---------------------------------------------------#
#   yolo_body
#---------------------------------------------------#
class YoloBody(nn.Module):
    """YOLOv4-style detector: backbone -> SPP -> PANet neck -> three heads.

    Args:
        anchors_mask:       per-scale anchor index lists; len(anchors_mask[i])
                            sets the anchor count of each head.
        num_classes:        number of object classes.
        backbone:           one of mobilenetv1/mobilenetv2/mobilenetv3/ghostnet/
                            CSPdarknet/vgg/densenet121/densenet169/densenet201/
                            resnet50.
        pretrained:         load pretrained backbone weights.
        phi:                attention type forwarded to the backbone
                            (0 = none; 1..4 index attention_block).
        spp:                pooling structure; 0 uses SpatialPyramidPooling with
                            the pool_* switches, >= 1 selects SppType[spp].
        pool_5/pool_9/pool_13: branch switches forwarded to
                            SpatialPyramidPooling when spp == 0.
        add_feature_layer:  >= 1 adds an extra shallow-feature branch; only
                            supported by backbones that can return four
                            feature maps (the mobilenets, ghostnet, CSPdarknet).

    Returns (forward): three raw prediction maps
        out0: (batch, A*(5+C), 13, 13), out1: (..., 26, 26), out2: (..., 52, 52)
        for a 416x416 input.
    """
    def __init__(self, anchors_mask, num_classes, backbone="mobilenetv1", pretrained=False, phi=0,
                 spp = 0, pool_5=0, pool_9=0, pool_13=0,
                 add_feature_layer = 0
                 ):
        super(YoloBody, self).__init__()

        #   attention mechanism selector
        self.phi = phi

        #   pooling structure selector
        self.spp = spp

        #   extra shallow feature layer switch
        self.add_feature_layer = add_feature_layer

        #   Channel count of the extra shallow feature map. Stays None for
        #   backbones that only provide three feature maps.
        add_infilter = None

        #---------------------------------------------------#
        #   Build the backbone and record the channel counts
        #   of its three effective feature layers.
        #---------------------------------------------------#
        if backbone == "mobilenetv1":
            #   52,52,256 / 26,26,512 / 13,13,1024
            self.backbone   = MobileNetV1(pretrained=pretrained, phi=self.phi, add_feature_layer=self.add_feature_layer)
            in_filters      = [256, 512, 1024]
            add_infilter    = 128
        elif backbone == "mobilenetv2":
            #   52,52,32 / 26,26,96 / 13,13,320
            self.backbone   = MobileNetV2(pretrained=pretrained, phi=self.phi, add_feature_layer=self.add_feature_layer)
            in_filters      = [32, 96, 320]
            add_infilter    = 24
        elif backbone == "mobilenetv3":
            #   52,52,40 / 26,26,112 / 13,13,160
            self.backbone   = MobileNetV3(pretrained=pretrained, phi=self.phi, add_feature_layer=self.add_feature_layer)
            in_filters      = [40, 112, 160]
            add_infilter    = 24
        elif backbone == "ghostnet":
            #   52,52,40 / 26,26,112 / 13,13,160
            self.backbone   = GhostNet(pretrained=pretrained, phi=self.phi, add_feature_layer=self.add_feature_layer)
            in_filters      = [40, 112, 160]
            add_infilter    = 24
        elif backbone == "CSPdarknet":
            #   52,52,256 / 26,26,512 / 13,13,1024
            self.backbone   = CSPdarknet(pretrained=pretrained, phi=self.phi, add_feature_layer=self.add_feature_layer)
            in_filters      = [256, 512, 1024]
            add_infilter    = 128
        elif backbone == "vgg":
            #   52,52,256 / 26,26,512 / 13,13,512
            self.backbone   = VGG(pretrained=pretrained)
            in_filters      = [256, 512, 512]
        elif backbone in ["densenet121", "densenet169", "densenet201"]:
            #   52,52,256 / 26,26,512 / 13,13,1024 (densenet121)
            self.backbone   = Densenet(backbone, pretrained=pretrained)
            in_filters = {
                "densenet121" : [256, 512, 1024],
                "densenet169" : [256, 640, 1664],
                "densenet201" : [256, 896, 1920]
            }[backbone]
        elif backbone == "resnet50":
            #   52,52,512 / 26,26,1024 / 13,13,2048
            self.backbone   = ResNet(pretrained=pretrained)
            in_filters      = [512, 1024, 2048]
        else:
            raise ValueError('Unsupported backbone - `{}`, Use mobilenetv1, mobilenetv2, mobilenetv3, ghostnet, CSPdarknet, vgg, densenet121, densenet169, densenet201, resnet50.'.format(backbone))

        #   Fail fast: the vgg/densenet/resnet50 backbones only return three
        #   feature maps, so the extra shallow branch cannot be built for them.
        #   (Previously this crashed later with a NameError on `add_infilter`.)
        if self.add_feature_layer >= 1 and add_infilter is None:
            raise ValueError('add_feature_layer is not supported for backbone `{}`.'.format(backbone))

        #   13,13,C -> 13,13,512 neck entry
        self.conv1           = make_three_conv([512, 1024], in_filters[2])

        #   Pooling structure over the deepest feature map.
        if self.spp == 0:
            self.SPP_ty         = SpatialPyramidPooling(PoolType=PoolType, pool_5=pool_5, pool_9=pool_9, pool_13=pool_13)
        else:
            self.SPP_ty         = SppType[self.spp]()
        #   SPP concatenation yields 2048 channels.
        self.conv2           = make_three_conv([512, 1024], 2048)

        #   Top-down path (P5 -> P4 -> P3).
        self.upsample1       = Upsample(512, 256)
        self.conv_for_P4     = conv2d(in_filters[1], 256, 1)
        self.make_five_conv1 = make_five_conv([256, 512], 512)

        self.upsample2       = Upsample(256, 128)
        self.conv_for_P3     = conv2d(in_filters[0], 128, 1)
        self.make_five_conv2 = make_five_conv([128, 256], 256)

        #   Optional extra shallow branch: fuse a 104x104 map back into P3.
        if self.add_feature_layer >= 1:
            self.upsample_add        = Upsample(128, 64)
            self.conv_for_add        = conv2d(add_infilter, 64, 1)
            self.make_five_conv_add1 = make_five_conv([64, 128], 128)
            self.down_sample_add     = conv_dw(64, 128, stride=2)
            self.make_five_conv_add2 = make_five_conv([128, 256], 256)

        #   Head output channels: num_anchors * (4 box + 1 obj + num_classes),
        #   e.g. 3*(5+20) = 75 for VOC.
        self.yolo_head3      = yolo_head([256, len(anchors_mask[0]) * (5 + num_classes)], 128)

        #   Bottom-up path (P3 -> P4 -> P5).
        self.down_sample1    = conv_dw(128, 256, stride=2)
        self.make_five_conv3 = make_five_conv([256, 512], 512)

        self.yolo_head2      = yolo_head([512, len(anchors_mask[1]) * (5 + num_classes)], 256)

        self.down_sample2    = conv_dw(256, 512, stride=2)
        self.make_five_conv4 = make_five_conv([512, 1024], 1024)

        self.yolo_head1      = yolo_head([1024, len(anchors_mask[2]) * (5 + num_classes)], 512)


    def forward(self, x):
        #   Backbone features, deepest last (x0). x3 is the optional
        #   extra-shallow map.
        if self.add_feature_layer >= 1:
            x3, x2, x1, x0 = self.backbone(x)
        else:
            x2, x1, x0 = self.backbone(x)

        # 13,13,C -> 13,13,512 -> (SPP) 13,13,2048 -> 13,13,512
        P5 = self.conv1(x0)
        P5 = self.SPP_ty(P5)
        P5 = self.conv2(P5)

        # 13,13,512 -> 26,26,256; fuse with the mid-level feature
        P5_upsample = self.upsample1(P5)
        P4 = self.conv_for_P4(x1)
        P4 = torch.cat([P4, P5_upsample], dim=1)
        # 26,26,512 -> 26,26,256
        P4 = self.make_five_conv1(P4)

        # 26,26,256 -> 52,52,128; fuse with the shallow feature
        P4_upsample = self.upsample2(P4)
        P3 = self.conv_for_P3(x2)
        P3 = torch.cat([P3, P4_upsample], dim=1)
        # 52,52,256 -> 52,52,128
        P3 = self.make_five_conv2(P3)

        if self.add_feature_layer >= 1:
            # 52,52,128 -> 104,104,64; fuse with the extra-shallow feature
            P_upsample_add = self.upsample_add(P3)
            P_add = self.conv_for_add(x3)
            P_add = torch.cat([P_add, P_upsample_add], dim=1)
            # 104,104,128 -> 104,104,64
            P_add = self.make_five_conv_add1(P_add)
            # 104,104,64 -> 52,52,128; fold back into P3
            P_down_sample_add = self.down_sample_add(P_add)
            P3 = torch.cat([P_down_sample_add, P3], dim=1)
            # 52,52,256 -> 52,52,128
            P3 = self.make_five_conv_add2(P3)

        # Bottom-up fusion: 52,52,128 -> 26,26,256, concat, refine
        P3_downsample = self.down_sample1(P3)
        P4 = torch.cat([P3_downsample, P4], dim=1)
        P4 = self.make_five_conv3(P4)

        # 26,26,256 -> 13,13,512, concat, refine
        P4_downsample = self.down_sample2(P4)
        P5 = torch.cat([P4_downsample, P5], dim=1)
        P5 = self.make_five_conv4(P5)

        #---------------------------------------------------#
        #   Third head:  y3 = (batch, A*(5+C), 52, 52)
        #---------------------------------------------------#
        out2 = self.yolo_head3(P3)
        #---------------------------------------------------#
        #   Second head: y2 = (batch, A*(5+C), 26, 26)
        #---------------------------------------------------#
        out1 = self.yolo_head2(P4)
        #---------------------------------------------------#
        #   First head:  y1 = (batch, A*(5+C), 13, 13)
        #---------------------------------------------------#
        out0 = self.yolo_head1(P5)

        return out0, out1, out2

